{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Importing relevant packages for finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
    "os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timesfm\n",
    "import gc\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from timesfm import patched_decoder\n",
    "from timesfm import data_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import dataclasses\n",
    "import IPython\n",
    "import IPython.display\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "mpl.rcParams['figure.figsize'] = (8, 6)\n",
    "mpl.rcParams['axes.grid'] = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading TimesFM pretrained checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesfm_backend = \"gpu\"  # @param\n",
    "\n",
    "tfm = timesfm.TimesFm(\n",
    "      hparams=timesfm.TimesFmHparams(\n",
    "          backend=timesfm_backend,\n",
    "          per_core_batch_size=32,\n",
    "          horizon_len=128,\n",
    "          num_layers=50,\n",
    "          # Se this to True for v1.0 checkpoints\n",
    "          use_positional_embedding=False,\n",
    "          # Note that we could set this to as high as 2048 but keeping it 512 here so that\n",
    "          # both v1.0 and 2.0 checkpoints work\n",
    "          context_len=512,\n",
    "      ),\n",
    "      checkpoint=timesfm.TimesFmCheckpoint(\n",
    "          huggingface_repo_id=\"google/timesfm-2.0-500m-jax\"),\n",
    "  )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluating pretrained checkpoint on ETT datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DICT = {\n",
    "    \"ettm2\": {\n",
    "        \"boundaries\": [34560, 46080, 57600],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTm2.csv\",\n",
    "        \"freq\": \"15min\",\n",
    "    },\n",
    "    \"ettm1\": {\n",
    "        \"boundaries\": [34560, 46080, 57600],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTm1.csv\",\n",
    "        \"freq\": \"15min\",\n",
    "    },\n",
    "    \"etth2\": {\n",
    "        \"boundaries\": [8640, 11520, 14400],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTh2.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"etth1\": {\n",
    "        \"boundaries\": [8640, 11520, 14400],\n",
    "        \"data_path\": \"../datasets/ETT-small/ETTh1.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"elec\": {\n",
    "        \"boundaries\": [18413, 21044, 26304],\n",
    "        \"data_path\": \"../datasets/electricity/electricity.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"traffic\": {\n",
    "        \"boundaries\": [12280, 14036, 17544],\n",
    "        \"data_path\": \"../datasets/traffic/traffic.csv\",\n",
    "        \"freq\": \"H\",\n",
    "    },\n",
    "    \"weather\": {\n",
    "        \"boundaries\": [36887, 42157, 52696],\n",
    "        \"data_path\": \"../datasets/weather/weather.csv\",\n",
    "        \"freq\": \"10min\",\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"ettm1\"\n",
    "data_path = DATA_DICT[dataset][\"data_path\"]\n",
    "freq = DATA_DICT[dataset][\"freq\"]\n",
    "int_freq = timesfm.freq_map(freq)\n",
    "boundaries = DATA_DICT[dataset][\"boundaries\"]\n",
    "\n",
    "data_df = pd.read_csv(open(data_path, \"r\"))\n",
    "\n",
    "\n",
    "ts_cols = [col for col in data_df.columns if col != \"date\"]\n",
    "num_cov_cols = None\n",
    "cat_cov_cols = None\n",
    "\n",
    "context_len = 512\n",
    "pred_len = 96\n",
    "\n",
    "num_ts = len(ts_cols)\n",
    "batch_size = 8\n",
    "\n",
    "dtl = data_loader.TimeSeriesdata(\n",
    "      data_path=data_path,\n",
    "      datetime_col=\"date\",\n",
    "      num_cov_cols=num_cov_cols,\n",
    "      cat_cov_cols=cat_cov_cols,\n",
    "      ts_cols=np.array(ts_cols),\n",
    "      train_range=[0, boundaries[0]],\n",
    "      val_range=[boundaries[0], boundaries[1]],\n",
    "      test_range=[boundaries[1], boundaries[2]],\n",
    "      hist_len=context_len,\n",
    "      pred_len=pred_len,\n",
    "      batch_size=num_ts,\n",
    "      freq=freq,\n",
    "      normalize=True,\n",
    "      epoch_len=None,\n",
    "      holiday=False,\n",
    "      permute=True,\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_batches = dtl.tf_dataset(mode=\"train\", shift=1).batch(batch_size)\n",
    "val_batches = dtl.tf_dataset(mode=\"val\", shift=pred_len)\n",
    "test_batches = dtl.tf_dataset(mode=\"test\", shift=pred_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for tbatch in tqdm(train_batches.as_numpy_iterator()):\n",
    "    break\n",
    "print(tbatch[0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MAE on the test split for the pretrained TimesFM model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mae_losses = []\n",
    "for batch in tqdm(test_batches.as_numpy_iterator()):\n",
    "    past = batch[0]\n",
    "    actuals = batch[3]\n",
    "    forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)\n",
    "    forecasts = forecasts[:, 0 : actuals.shape[1]]\n",
    "    mae_losses.append(np.abs(forecasts - actuals).mean())\n",
    "\n",
    "print(f\"MAE: {np.mean(mae_losses)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetuning the model on the ETT dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "from praxis import pax_fiddle\n",
    "from praxis import py_utils\n",
    "from praxis import pytypes\n",
    "from praxis import base_model\n",
    "from praxis import optimizers\n",
    "from praxis import schedules\n",
    "from praxis import base_hyperparams\n",
    "from praxis import base_layer\n",
    "from paxml import tasks_lib\n",
    "from paxml import trainer_lib\n",
    "from paxml import checkpoints\n",
    "from paxml import learners\n",
    "from paxml import partitioning\n",
    "from paxml import checkpoint_types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PAX shortcuts\n",
    "NestedMap = py_utils.NestedMap\n",
    "WeightInit = base_layer.WeightInit\n",
    "WeightHParams = base_layer.WeightHParams\n",
    "InstantiableParams = py_utils.InstantiableParams\n",
    "JTensor = pytypes.JTensor\n",
    "NpTensor = pytypes.NpTensor\n",
    "WeightedScalars = pytypes.WeightedScalars\n",
    "instantiate = base_hyperparams.instantiate\n",
    "LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]\n",
    "AuxLossStruct = base_layer.AuxLossStruct\n",
    "\n",
    "AUX_LOSS = base_layer.AUX_LOSS\n",
    "template_field = base_layer.template_field\n",
    "\n",
    "# Standard prng key names\n",
    "PARAMS = base_layer.PARAMS\n",
    "RANDOM = base_layer.RANDOM\n",
    "\n",
    "key = jax.random.PRNGKey(seed=1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = pax_fiddle.Config(\n",
    "    patched_decoder.PatchedDecoderFinetuneModel,\n",
    "    name='patched_decoder_finetune',\n",
    "    core_layer_tpl=tfm.model_p,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We will hold the transformer layers fixed while finetuning, while training all other components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pax_fiddle.auto_config\n",
    "def build_learner() -> learners.Learner:\n",
    "  return pax_fiddle.Config(\n",
    "      learners.Learner,\n",
    "      name='learner',\n",
    "      loss_name='avg_qloss',\n",
    "      optimizer=optimizers.Adam(\n",
    "          epsilon=1e-7,\n",
    "          clip_threshold=1e2,\n",
    "          learning_rate=1e-2,\n",
    "          lr_schedule=pax_fiddle.Config(\n",
    "              schedules.Cosine,\n",
    "              initial_value=1e-3,\n",
    "              final_value=1e-4,\n",
    "              total_steps=40000,\n",
    "          ),\n",
    "          ema_decay=0.9999,\n",
    "      ),\n",
    "      # Linear probing i.e we hold the transformer layers fixed.\n",
    "      bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_p = tasks_lib.SingleTask(\n",
    "    name='ts-learn',\n",
    "    model=model,\n",
    "    train=tasks_lib.SingleTask.Train(\n",
    "        learner=build_learner(),\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_p.model.ici_mesh_shape = [1, 1, 1]\n",
    "task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']\n",
    "\n",
    "DEVICES = np.array(jax.devices()).reshape([1, 1, 1])\n",
    "MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])\n",
    "\n",
    "num_devices = jax.local_device_count()\n",
    "print(f'num_devices: {num_devices}')\n",
    "print(f'device kind: {jax.local_devices()[0].device_kind}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jax_task = task_p\n",
    "key, init_key = jax.random.split(key)\n",
    "\n",
    "# To correctly prepare a batch of data for model initialization (now that shape\n",
    "# inference is merged), we take one devices*batch_size tensor tuple of data,\n",
    "# slice out just one batch, then run the prepare_input_batch function over it.\n",
    "\n",
    "\n",
    "def process_train_batch(batch):\n",
    "    past_ts = batch[0].reshape(batch_size * num_ts, -1)\n",
    "    actual_ts = batch[3].reshape(batch_size * num_ts, -1)\n",
    "    return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\n",
    "\n",
    "\n",
    "def process_eval_batch(batch):\n",
    "    past_ts = batch[0]\n",
    "    actual_ts = batch[3]\n",
    "    return NestedMap(input_ts=past_ts, actual_ts=actual_ts)\n",
    "\n",
    "\n",
    "jax_model_states, _ = trainer_lib.initialize_model_state(\n",
    "    jax_task,\n",
    "    init_key,\n",
    "    process_train_batch(tbatch),\n",
    "    checkpoint_type=checkpoint_types.CheckpointType.GDA,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting the initial model weights to the pretrained TimesFM parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']\n",
    "jax_vars = jax_model_states.mdl_vars\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "jax_task = task_p\n",
    "\n",
    "\n",
    "def train_step(states, prng_key, inputs):\n",
    "  return trainer_lib.train_step_single_learner(\n",
    "      jax_task, states, prng_key, inputs\n",
    "  )\n",
    "\n",
    "\n",
    "def eval_step(states, prng_key, inputs):\n",
    "  states = states.to_eval_state()\n",
    "  return trainer_lib.eval_step_single_learner(\n",
    "      jax_task, states, prng_key, inputs\n",
    "  )\n",
    "\n",
    "key, train_key, eval_key = jax.random.split(key, 3)\n",
    "train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())\n",
    "eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())\n",
    "\n",
    "p_train_step = jax.pmap(train_step, axis_name='batch')\n",
    "p_eval_step = jax.pmap(eval_step, axis_name='batch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)\n",
    "replicated_jax_vars = replicated_jax_states.mdl_vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_eval_loss = 1e7\n",
    "step_count = 0\n",
    "patience = 0\n",
    "NUM_EPOCHS = 100\n",
    "PATIENCE = 5\n",
    "TRAIN_STEPS_PER_EVAL = 1000\n",
    "CHECKPOINT_DIR='/home/senrajat_google_com/ettm1_finetune'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reshape_batch_for_pmap(batch, num_devices):\n",
    "  def _reshape(input_tensor):\n",
    "    bsize = input_tensor.shape[0]\n",
    "    residual_shape = list(input_tensor.shape[1:])\n",
    "    nbsize = bsize // num_devices\n",
    "    return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)\n",
    "\n",
    "  return jax.tree.map(_reshape, batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(NUM_EPOCHS):\n",
    "    print(f\"__________________Epoch: {epoch}__________________\", flush=True)\n",
    "    train_its = train_batches.as_numpy_iterator()\n",
    "    if patience >= PATIENCE:\n",
    "        print(\"Early stopping.\", flush=True)\n",
    "        break\n",
    "    for batch in tqdm(train_its):\n",
    "        train_losses = []\n",
    "        if patience >= PATIENCE:\n",
    "            print(\"Early stopping.\", flush=True)\n",
    "            break\n",
    "        tbatch = process_train_batch(batch)\n",
    "        tbatch = reshape_batch_for_pmap(tbatch, num_devices)\n",
    "        replicated_jax_states, step_fun_out = p_train_step(\n",
    "            replicated_jax_states, train_prng_seed, tbatch\n",
    "        )\n",
    "        train_losses.append(step_fun_out.loss[0])\n",
    "        if step_count % TRAIN_STEPS_PER_EVAL == 0:\n",
    "            print(\n",
    "                f\"Train loss at step {step_count}: {np.mean(train_losses)}\",\n",
    "                flush=True,\n",
    "            )\n",
    "            train_losses = []\n",
    "            print(\"Starting eval.\", flush=True)\n",
    "            val_its = val_batches.as_numpy_iterator()\n",
    "            eval_losses = []\n",
    "            for ev_batch in tqdm(val_its):\n",
    "                ebatch = process_eval_batch(ev_batch)\n",
    "                ebatch = reshape_batch_for_pmap(ebatch, num_devices)\n",
    "                _, step_fun_out = p_eval_step(\n",
    "                    replicated_jax_states, eval_prng_seed, ebatch\n",
    "                )\n",
    "                eval_losses.append(step_fun_out.loss[0])\n",
    "            mean_loss = np.mean(eval_losses)\n",
    "            print(f\"Eval loss at step {step_count}: {mean_loss}\", flush=True)\n",
    "            if mean_loss < best_eval_loss or np.isnan(mean_loss):\n",
    "                best_eval_loss = mean_loss\n",
    "                print(\"Saving checkpoint.\")\n",
    "                jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(\n",
    "                    replicated_jax_states\n",
    "                )\n",
    "                checkpoints.save_checkpoint(\n",
    "                    jax_state_for_saving, CHECKPOINT_DIR, overwrite=True\n",
    "                )\n",
    "                patience = 0\n",
    "                del jax_state_for_saving\n",
    "                gc.collect()\n",
    "            else:\n",
    "                patience += 1\n",
    "                print(f\"patience: {patience}\")\n",
    "        step_count += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading and evaluating the best (according to validation loss) finetuned checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)\n",
    "print(train_state.step)\n",
    "tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']\n",
    "tfm.jit_decode()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mae_losses = []\n",
    "for batch in tqdm(test_batches.as_numpy_iterator()):\n",
    "    past = batch[0]\n",
    "    actuals = batch[3]\n",
    "    _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])\n",
    "    forecasts = forecasts[:, 0 : actuals.shape[1], 5]\n",
    "    mae_losses.append(np.abs(forecasts - actuals).mean())\n",
    "\n",
    "print(f\"MAE: {np.mean(mae_losses)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## There is around a __7%__ reduction in MAE from finetuning."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chronos-v2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
